
import argparse
import json
import os
import re

from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm

# Load variables from a local .env file (python-dotenv) into the process
# environment, so the script entry point can read OPENAI_API_KEY via os.getenv
# without it being exported in the shell.
load_dotenv()


def llm_consensus_gen(text: list, api: str, verbose: bool = False, *,
                      model: str = "gpt-4o-mini-2024-07-18"):
    """Ask an OpenAI chat model to merge five candidate texts into one consensus text.

    Args:
        text: Candidate texts. Only the first five entries are placed in the
            prompt; passing fewer than five raises ``IndexError``.
        api: OpenAI API key handed to the client.
        verbose: If true, print the extracted consensus.
        model: Chat-completion model name (keyword-only). Defaults to the
            model snapshot this script was written against.

    Returns:
        The text after the first ``consensus text:`` marker in the model's
        reply (matched case-insensitively, surrounding whitespace stripped).
        If the marker is absent, the whole stripped reply. If the reply
        carries no text content, an empty string.
    """
    client = OpenAI(api_key=api)

    consensus_llm_prompt = f"""
    You are given 5 texts. Your task is to form/generate a consensus/agreement text using the given texts. Consensus or agreement would mean producing a new text that uses the given 5 texts to find a coherent text that includes words and information that is consistent across all the given texts.
    Text 1: {text[0]}
    Text 2: {text[1]}
    Text 3: {text[2]}
    Text 4: {text[3]}
    Text 5: {text[4]}

    Here are some important guidelines:
    - If the texts differ at a certain point/word, the consensus text should select the most frequent word from among the given texts at the point of difference.
    - If the texts differ at a certain point/word and there is no most frequent word, the consensus text should select the word that is most similar to the other words in the text.
    - Abstain if the texts are too different and no consensus can be reached.
   
    Strictly follow the guidelines above, especially regarding abstaining if the texts are too different.

    Return your generation in the following format. Do not include any other text:

    consensus text: [your consensus text here]

    """

    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": consensus_llm_prompt},
        ],
    )

    # message.content can be None (e.g. when the model refuses); treat that as
    # an empty reply instead of crashing on .strip().
    reply = (completion.choices[0].message.content or "").strip()

    # Models do not always reproduce the marker's exact casing/spacing, so
    # match it case-insensitively and keep everything after the FIRST
    # occurrence (the consensus itself may legitimately repeat the phrase).
    match = re.search(r"consensus text:\s*", reply, flags=re.IGNORECASE)
    consensus = reply[match.end():].strip() if match else reply

    if verbose:
        print(consensus)
    return consensus

def _process_file(in_path: str, out_path: str, api: str) -> None:
    """Replace every topic's "Responses" in one JSON file with an LLM consensus.

    The file at ``in_path`` is expected to hold an iterable of dicts, each with
    a "Responses" entry (presumably a list of at least five texts, as
    llm_consensus_gen requires; verify against the data producer). Each
    "Responses" value is replaced by a one-element list holding the consensus,
    and the modified data is written to ``out_path`` as indented JSON.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(in_path, "r", encoding="utf-8") as json_file:
        data = json.load(json_file)

    # leave=False so the per-file bar disappears instead of stacking up.
    for topic in tqdm(data, desc=os.path.basename(in_path), leave=False):
        consensus = llm_consensus_gen(topic["Responses"], api, verbose=False)
        topic["Responses"] = [consensus]

    with open(out_path, "w", encoding="utf-8") as json_file:
        json.dump(data, json_file, indent=4)


def main() -> None:
    """Command-line entry point: build consensus texts for each JSON file in --input_dir.

    Results are written under the same file names into --output_dir, which is
    created if it does not exist.
    """
    parser = argparse.ArgumentParser(description="Generate LLM-based consensus from responses.")
    parser.add_argument('--input_dir', type=str, required=True, help="Input dir path to the input JSON files.")
    parser.add_argument('--output_dir', type=str, required=True, help="Output dir path to the output JSON files.")
    args = parser.parse_args()

    api = os.getenv("OPENAI_API_KEY")
    if not api:
        # Fail before touching any files rather than on the first API call.
        parser.error("OPENAI_API_KEY is not set (environment or .env file).")

    # open(..., "w") below fails if the output directory is missing.
    os.makedirs(args.output_dir, exist_ok=True)

    # os.listdir also returns subdirectories and unrelated files that
    # open/json.load would crash on; keep regular *.json files only, sorted so
    # runs process files in a reproducible order.
    filenames = sorted(
        name
        for name in os.listdir(args.input_dir)
        if name.lower().endswith(".json")
        and os.path.isfile(os.path.join(args.input_dir, name))
    )

    for filename in tqdm(filenames):
        _process_file(
            os.path.join(args.input_dir, filename),
            os.path.join(args.output_dir, filename),
            api,
        )
        # tqdm.write prints without garbling the active progress bars.
        tqdm.write(f"Done {filename}")


if __name__ == "__main__":
    main()